2D Clusterless Decoding
2D Clusterless Decoding
In [ ]:
Copied!
# Reload imported modules automatically before each cell executes, so edits to
# imported library source take effect without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Reload imported modules automatically before each cell executes, so edits to
# imported library source take effect without restarting the kernel.
%reload_ext autoreload
%autoreload 2
In [2]:
Copied!
# Silence the DeprecationWarning / ResourceWarning noise that datajoint emits
# when running inside Jupyter's asynchronous kernel.
import warnings

for _category in (DeprecationWarning, ResourceWarning):
    warnings.simplefilter("ignore", category=_category)
del _category
# Silence the DeprecationWarning / ResourceWarning noise that datajoint emits
# when running inside Jupyter's asynchronous kernel.
import warnings

for _category in (DeprecationWarning, ResourceWarning):
    warnings.simplefilter("ignore", category=_category)
del _category
In [3]:
Copied!
import logging

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# Timestamped log lines at INFO level, so progress messages such as the
# classifier's "Fitting ..." / "Estimating ..." steps are shown in the output.
FORMAT = "%(asctime)s %(message)s"
logging.basicConfig(
    level="INFO",
    format=FORMAT,
    datefmt="%d-%b-%y %H:%M:%S",
)
import logging

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# Timestamped log lines at INFO level, so progress messages such as the
# classifier's "Fitting ..." / "Estimating ..." steps are shown in the output.
FORMAT = "%(asctime)s %(message)s"
logging.basicConfig(
    level="INFO",
    format=FORMAT,
    datefmt="%d-%b-%y %H:%M:%S",
)
In [4]:
Copied!
# NWB file (copy) that every table restriction in this notebook is keyed on.
nwb_copy_file_name: str = "chimi20200216_new_.nwb"
# NWB file (copy) that every table restriction in this notebook is keyed on.
nwb_copy_file_name: str = "chimi20200216_new_.nwb"
In [5]:
Copied!
from spyglass.decoding.clusterless import UnitMarksIndicator

# Restriction selecting one set of binned unit marks for the
# "pos 1 valid times" interval (sampling_rate 500).
marks_key = dict(
    nwb_file_name=nwb_copy_file_name,
    sort_interval_name="runs_noPrePostTrialTimes raw data valid times",
    filter_parameter_set_name="franklab_default_hippocampus",
    unit_inclusion_param_name="all2",
    mark_param_name="default",
    interval_list_name="pos 1 valid times",
    sampling_rate=500,
)
marks_table = UnitMarksIndicator & marks_key

# DataArray with dims (time, marks, electrodes); NaN where an electrode has no mark.
marks = marks_table.fetch_xarray()
marks
from spyglass.decoding.clusterless import UnitMarksIndicator

# Restriction selecting one set of binned unit marks for the
# "pos 1 valid times" interval (sampling_rate 500).
marks_key = dict(
    nwb_file_name=nwb_copy_file_name,
    sort_interval_name="runs_noPrePostTrialTimes raw data valid times",
    filter_parameter_set_name="franklab_default_hippocampus",
    unit_inclusion_param_name="all2",
    mark_param_name="default",
    interval_list_name="pos 1 valid times",
    sampling_rate=500,
)
marks_table = UnitMarksIndicator & marks_key

# DataArray with dims (time, marks, electrodes); NaN where an electrode has no mark.
marks = marks_table.fetch_xarray()
marks
/home/edeno/miniconda3/envs/spyglass/lib/python3.8/site-packages/seaborn/rcmod.py:82: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead. if LooseVersion(mpl.__version__) >= "3.0": /home/edeno/miniconda3/envs/spyglass/lib/python3.8/site-packages/setuptools/_distutils/version.py:346: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead. other = LooseVersion(other) 13-Sep-22 15:42:32 Connected edeno@lmf-db.cin.ucsf.edu:3306
Connecting edeno@lmf-db.cin.ucsf.edu:3306
/stelmo/nwb/analysis/chimi20200216_new_7M0E8ERPE7.nwb /stelmo/nwb/analysis/chimi20200216_new_6WW86B509M.nwb /stelmo/nwb/analysis/chimi20200216_new_TLD0MCIC5H.nwb /stelmo/nwb/analysis/chimi20200216_new_7BEQDOTX3E.nwb /stelmo/nwb/analysis/chimi20200216_new_F8QVNUMVJS.nwb /stelmo/nwb/analysis/chimi20200216_new_BVZKYWREUE.nwb /stelmo/nwb/analysis/chimi20200216_new_3HMJON557D.nwb /stelmo/nwb/analysis/chimi20200216_new_QGMZ5ESFVA.nwb /stelmo/nwb/analysis/chimi20200216_new_1KRVBBCP2N.nwb /stelmo/nwb/analysis/chimi20200216_new_9E2Z0R6TLO.nwb /stelmo/nwb/analysis/chimi20200216_new_ALRF0STB1P.nwb /stelmo/nwb/analysis/chimi20200216_new_F2TDZW8LRY.nwb /stelmo/nwb/analysis/chimi20200216_new_LTEU71Z21T.nwb /stelmo/nwb/analysis/chimi20200216_new_KT4E4LIYAI.nwb /stelmo/nwb/analysis/chimi20200216_new_KOIRLX6R6X.nwb /stelmo/nwb/analysis/chimi20200216_new_4S01EA6NVN.nwb /stelmo/nwb/analysis/chimi20200216_new_ATQO860QOB.nwb /stelmo/nwb/analysis/chimi20200216_new_H3E2HYMEJA.nwb /stelmo/nwb/analysis/chimi20200216_new_4KJ4XVBKW3.nwb /stelmo/nwb/analysis/chimi20200216_new_0V98T6HQHX.nwb /stelmo/nwb/analysis/chimi20200216_new_A5FBXFDZMD.nwb /stelmo/nwb/analysis/chimi20200216_new_A5ELOH1L7Y.nwb
Out[5]:
<xarray.DataArray (time: 655645, marks: 4, electrodes: 22)>
array([[[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan]],
[[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan]],
[[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan]],
...,
[[ -99., nan, nan, ..., nan, nan, nan],
[-100., nan, nan, ..., nan, nan, nan],
[ -94., nan, nan, ..., nan, nan, nan],
[-104., nan, nan, ..., nan, nan, nan]],
[[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan]],
[[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan]]])
Coordinates:
* time (time) float64 1.582e+09 1.582e+09 ... 1.582e+09 1.582e+09
* marks (marks) <U14 'amplitude_0000' ... 'amplitude_0003'
* electrodes (electrodes) int64 0 1 2 3 5 6 7 8 9 ... 15 16 17 18 19 21 22 23
In [6]:
Copied!
# Scatter the first two mark features of the first electrode against each
# other. dropna("time") keeps only time bins where this electrode has marks;
# it is computed once here instead of once per axis.
first_electrode_marks = marks.isel(electrodes=0).dropna("time")
plt.scatter(
    first_electrode_marks.isel(marks=0),
    first_electrode_marks.isel(marks=1),
    s=1,
)
plt.xlabel(str(first_electrode_marks["marks"].values[0]))
plt.ylabel(str(first_electrode_marks["marks"].values[1]))
plt.title("Marks of the first electrode");
# Scatter the first two mark features of the first electrode against each
# other. dropna("time") keeps only time bins where this electrode has marks;
# it is computed once here instead of once per axis.
first_electrode_marks = marks.isel(electrodes=0).dropna("time")
plt.scatter(
    first_electrode_marks.isel(marks=0),
    first_electrode_marks.isel(marks=1),
    s=1,
)
plt.xlabel(str(first_electrode_marks["marks"].values[0]))
plt.ylabel(str(first_electrode_marks["marks"].values[1]))
plt.title("Marks of the first electrode");
Out[6]:
<matplotlib.collections.PathCollection at 0x7f2ad72ae3a0>
In [7]:
Copied!
from spyglass.common.common_position import IntervalPositionInfo

# Position estimates for the same "pos 1 valid times" interval used for the
# marks, computed with the "default_decoding" position parameters.
position_key = dict(
    nwb_file_name=nwb_copy_file_name,
    interval_list_name="pos 1 valid times",
    position_info_param_name="default_decoding",
)
position_table = IntervalPositionInfo() & position_key

# Time-indexed frame with head position, orientation, velocity and speed columns.
position_info = position_table.fetch1_dataframe()
position_info
from spyglass.common.common_position import IntervalPositionInfo

# Position estimates for the same "pos 1 valid times" interval used for the
# marks, computed with the "default_decoding" position parameters.
position_key = dict(
    nwb_file_name=nwb_copy_file_name,
    interval_list_name="pos 1 valid times",
    position_info_param_name="default_decoding",
)
position_table = IntervalPositionInfo() & position_key

# Time-indexed frame with head position, orientation, velocity and speed columns.
position_info = position_table.fetch1_dataframe()
position_info
/stelmo/nwb/analysis/chimi20200216_new_6YC9LPAR7S.nwb
Out[7]:
| head_position_x | head_position_y | head_orientation | head_velocity_x | head_velocity_y | head_speed | |
|---|---|---|---|---|---|---|
| time | ||||||
| 1.581887e+09 | 91.051650 | 211.127050 | 2.680048 | 1.741550 | 2.301478 | 2.886139 |
| 1.581887e+09 | 91.039455 | 211.144123 | 3.003241 | 1.827555 | 2.333931 | 2.964320 |
| 1.581887e+09 | 91.027260 | 211.161196 | 3.008398 | 1.915800 | 2.366668 | 3.044898 |
| 1.581887e+09 | 91.015065 | 211.178268 | 3.012802 | 2.006286 | 2.399705 | 3.127901 |
| 1.581887e+09 | 91.002871 | 211.195341 | 3.017242 | 2.099012 | 2.433059 | 3.213352 |
| ... | ... | ... | ... | ... | ... | ... |
| 1.581888e+09 | 182.158583 | 201.299625 | -0.944304 | 0.057520 | -0.356012 | 0.360629 |
| 1.581888e+09 | 182.158583 | 201.296373 | -0.942329 | 0.053954 | -0.356343 | 0.360404 |
| 1.581888e+09 | 182.158583 | 201.293121 | -0.940357 | 0.050477 | -0.356407 | 0.359964 |
| 1.581888e+09 | 182.158583 | 201.289869 | -0.953059 | 0.047091 | -0.356212 | 0.359312 |
| 1.581888e+09 | 182.158583 | 201.286617 | -0.588081 | 0.043796 | -0.355764 | 0.358450 |
655645 rows × 6 columns
In [8]:
Copied!
# 2D head-position trajectory over the whole position interval.
plt.plot(position_info.head_position_x, position_info.head_position_y)
plt.xlabel("head_position_x")
plt.ylabel("head_position_y")
# Equal axis scaling so the spatial layout of the environment is not distorted.
plt.gca().set_aspect("equal")
# 2D head-position trajectory over the whole position interval.
plt.plot(position_info.head_position_x, position_info.head_position_y)
plt.xlabel("head_position_x")
plt.ylabel("head_position_y")
# Equal axis scaling so the spatial layout of the environment is not distorted.
plt.gca().set_aspect("equal")
Out[8]:
[<matplotlib.lines.Line2D at 0x7f2ad70abfd0>]
In [9]:
Copied!
# The classifier is fit and decoded on position rows and mark time bins taken
# side by side, so both must have the same number of time points (axis 0).
assert position_info.shape[0] == marks.shape[0], (
    f"time axis mismatch: {position_info.shape[0]} position samples "
    f"vs {marks.shape[0]} mark time bins"
)
position_info.shape, marks.shape
# The classifier is fit and decoded on position rows and mark time bins taken
# side by side, so both must have the same number of time points (axis 0).
assert position_info.shape[0] == marks.shape[0], (
    f"time axis mismatch: {position_info.shape[0]} position samples "
    f"vs {marks.shape[0]} mark time bins"
)
position_info.shape, marks.shape
Out[9]:
((655645, 6), (655645, 4, 22))
In [10]:
Copied!
from spyglass.common.common_interval import interval_list_intersect
from spyglass.common import IntervalList


def fetch_valid_times(nwb_file_name, interval_list_name):
    """Return the ``valid_times`` of a single IntervalList entry.

    Parameters
    ----------
    nwb_file_name : str
        NWB file the interval list belongs to.
    interval_list_name : str
        Name of the interval list.

    Returns
    -------
    The ``valid_times`` attribute, fetched with ``fetch1`` (which expects
    exactly one matching row).
    """
    return (
        IntervalList
        & {
            "nwb_file_name": nwb_file_name,
            "interval_list_name": interval_list_name,
        }
    ).fetch1("valid_times")


# Run epoch to decode.
key = {
    "interval_list_name": "02_r1",
    "nwb_file_name": nwb_copy_file_name,
}

interval = fetch_valid_times(key["nwb_file_name"], key["interval_list_name"])
valid_ephys_times = fetch_valid_times(key["nwb_file_name"], "raw data valid times")
# Use the position interval that `marks` and `position_info` were fetched for,
# instead of whichever interval an unordered IntervalPositionInfo fetch happens
# to return first, so the decoding window matches the decoded data.
valid_pos_times = fetch_valid_times(
    key["nwb_file_name"], position_key["interval_list_name"]
)

intersect_interval = interval_list_intersect(
    interval_list_intersect(interval, valid_ephys_times), valid_pos_times
)
if intersect_interval is None or len(intersect_interval) == 0:
    raise ValueError(
        f"{key['interval_list_name']!r} does not overlap both the raw data "
        f"valid times and {position_key['interval_list_name']!r}"
    )
# Only the first contiguous overlapping interval is used as the decoding window.
valid_time_slice = slice(intersect_interval[0][0], intersect_interval[0][1])
valid_time_slice
from spyglass.common.common_interval import interval_list_intersect
from spyglass.common import IntervalList


def fetch_valid_times(nwb_file_name, interval_list_name):
    """Return the ``valid_times`` of a single IntervalList entry.

    Parameters
    ----------
    nwb_file_name : str
        NWB file the interval list belongs to.
    interval_list_name : str
        Name of the interval list.

    Returns
    -------
    The ``valid_times`` attribute, fetched with ``fetch1`` (which expects
    exactly one matching row).
    """
    return (
        IntervalList
        & {
            "nwb_file_name": nwb_file_name,
            "interval_list_name": interval_list_name,
        }
    ).fetch1("valid_times")


# Run epoch to decode.
key = {
    "interval_list_name": "02_r1",
    "nwb_file_name": nwb_copy_file_name,
}

interval = fetch_valid_times(key["nwb_file_name"], key["interval_list_name"])
valid_ephys_times = fetch_valid_times(key["nwb_file_name"], "raw data valid times")
# Use the position interval that `marks` and `position_info` were fetched for,
# instead of whichever interval an unordered IntervalPositionInfo fetch happens
# to return first, so the decoding window matches the decoded data.
valid_pos_times = fetch_valid_times(
    key["nwb_file_name"], position_key["interval_list_name"]
)

intersect_interval = interval_list_intersect(
    interval_list_intersect(interval, valid_ephys_times), valid_pos_times
)
if intersect_interval is None or len(intersect_interval) == 0:
    raise ValueError(
        f"{key['interval_list_name']!r} does not overlap both the raw data "
        f"valid times and {position_key['interval_list_name']!r}"
    )
# Only the first contiguous overlapping interval is used as the decoding window.
valid_time_slice = slice(intersect_interval[0][0], intersect_interval[0][1])
valid_time_slice
Out[10]:
slice(1581886916.3153033, 1581888227.5987928, None)
In [11]:
Copied!
from replay_trajectory_classification import ClusterlessClassifier
from replay_trajectory_classification.environments import Environment
from replay_trajectory_classification.continuous_state_transitions import (
    RandomWalk,
    Uniform,
)
from spyglass.decoding.clusterless import ClusterlessClassifierParameters

# Restrict marks and position to the decoding window computed above.
marks = marks.sel(time=valid_time_slice)
position_info = position_info.loc[valid_time_slice]

# Start from the stored "default_decoding_gpu" parameters, then override the
# clusterless likelihood settings and the first environment's bin size.
parameters = (
    ClusterlessClassifierParameters()
    & {"classifier_param_name": "default_decoding_gpu"}
).fetch1()
classifier_params = parameters["classifier_params"]
classifier_params["clusterless_algorithm_params"] = dict(
    mark_std=24.0,
    position_std=3.0,
    block_size=2**13,
    disable_progress_bar=False,
    use_diffusion=False,
)
classifier_params["environments"][0] = Environment(place_bin_size=3.0)

import cupy as cp

# Fit and decode on GPU 0, using the same time window for both.
with cp.cuda.Device(0):
    classifier = ClusterlessClassifier(**classifier_params)
    head_position = position_info[["head_position_x", "head_position_y"]].values
    classifier.fit(
        position=head_position,
        multiunits=marks.values,
        **parameters["fit_params"]
    )
    results = classifier.predict(
        multiunits=marks.values,
        time=position_info.index,
        **parameters["predict_params"]
    )
logging.info("Done!")
from replay_trajectory_classification import ClusterlessClassifier
from replay_trajectory_classification.environments import Environment
from replay_trajectory_classification.continuous_state_transitions import (
    RandomWalk,
    Uniform,
)
from spyglass.decoding.clusterless import ClusterlessClassifierParameters

# Restrict marks and position to the decoding window computed above.
marks = marks.sel(time=valid_time_slice)
position_info = position_info.loc[valid_time_slice]

# Start from the stored "default_decoding_gpu" parameters, then override the
# clusterless likelihood settings and the first environment's bin size.
parameters = (
    ClusterlessClassifierParameters()
    & {"classifier_param_name": "default_decoding_gpu"}
).fetch1()
classifier_params = parameters["classifier_params"]
classifier_params["clusterless_algorithm_params"] = dict(
    mark_std=24.0,
    position_std=3.0,
    block_size=2**13,
    disable_progress_bar=False,
    use_diffusion=False,
)
classifier_params["environments"][0] = Environment(place_bin_size=3.0)

import cupy as cp

# Fit and decode on GPU 0, using the same time window for both.
with cp.cuda.Device(0):
    classifier = ClusterlessClassifier(**classifier_params)
    head_position = position_info[["head_position_x", "head_position_y"]].values
    classifier.fit(
        position=head_position,
        multiunits=marks.values,
        **parameters["fit_params"]
    )
    results = classifier.predict(
        multiunits=marks.values,
        time=position_info.index,
        **parameters["predict_params"]
    )
logging.info("Done!")
13-Sep-22 15:43:02 Fitting initial conditions... 13-Sep-22 15:43:02 Fitting continuous state transition... 13-Sep-22 15:43:04 Fitting discrete state transition 13-Sep-22 15:43:04 Fitting multiunits... 13-Sep-22 15:43:06 Estimating likelihood...
13-Sep-22 15:44:52 Estimating causal posterior... 13-Sep-22 15:48:34 Estimating acausal posterior... 13-Sep-22 16:00:30 Done!
In [16]:
Copied!
from spyglass.decoding.visualization import create_interactive_2D_decoding_figurl

# Interactive figurl view of the acausal posterior alongside head position,
# head direction and speed. (A stray double comma here was a SyntaxError.)
view = create_interactive_2D_decoding_figurl(
    position_info,
    marks,
    results,
    classifier.environments[0].place_bin_size,
    position_name=["head_position_x", "head_position_y"],
    head_direction_name="head_orientation",
    speed_name="head_speed",
    posterior_type="acausal_posterior",
    # Matches the sampling_rate of 500 used when fetching the marks.
    sampling_frequency=500,
    view_height=800,
)
from spyglass.decoding.visualization import create_interactive_2D_decoding_figurl

# Interactive figurl view of the acausal posterior alongside head position,
# head direction and speed. (A stray double comma here was a SyntaxError.)
view = create_interactive_2D_decoding_figurl(
    position_info,
    marks,
    results,
    classifier.environments[0].place_bin_size,
    position_name=["head_position_x", "head_position_y"],
    head_direction_name="head_orientation",
    speed_name="head_speed",
    posterior_type="acausal_posterior",
    # Matches the sampling_rate of 500 used when fetching the marks.
    sampling_frequency=500,
    view_height=800,
)
To view the decoding results inside the Jupyter notebook, uncomment and run the next cell. This renders the interactive visualization locally:
In [30]:
Copied!
# Uncomment the line below to render the interactive view inline; as the
# cell's last expression, Jupyter displays it.
# view
# Uncomment the line below to render the interactive view inline; as the
# cell's last expression, Jupyter displays it.
# view
To create a version of the visualization that can be shared via the cloud (a figurl URL), run the next cell:
In [19]:
Copied!
# Shareable figurl link for the interactive decoding view.
shareable_url = view.url(label="2D Decode Example")
shareable_url
# Shareable figurl link for the interactive decoding view.
shareable_url = view.url(label="2D Decode Example")
shareable_url
Out[19]:
'https://figurl.org/f?v=gs://figurl/spikesortingview-9&d=sha1://7251ae9794d4dfd783fabd9b84826c86b1982f74&label=2D%20Decode%20Example'